{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fc0db1bc",
   "metadata": {},
   "source": [
    "# Long-Context Reorder\n",
    "\n",
    "No matter the architecture of your model, there is a substantial performance degradation when you include 10+ retrieved documents.\n",
    "In brief: When models must access relevant information in the middle of long contexts, they tend to ignore the provided documents.\n",
    "See: https://arxiv.org/abs/2307.03172\n",
    "\n",
    "To avoid this issue you can re-order documents after retrieval to avoid performance degradation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74d1ebe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install sentence-transformers > /dev/null"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "49cbcd8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='This is a document about the Boston Celtics'),\n",
       " Document(page_content='The Celtics are my favourite team.'),\n",
       " Document(page_content='L. Kornet is one of the best Celtics players.'),\n",
       " Document(page_content='The Boston Celtics won the game by 20 points'),\n",
       " Document(page_content='Larry Bird was an iconic NBA player.'),\n",
       " Document(page_content='Elden Ring is one of the best games in the last 15 years.'),\n",
       " Document(page_content='Basquetball is a great sport.'),\n",
       " Document(page_content='I simply love going to the movies'),\n",
       " Document(page_content='Fly me to the moon is one of my favourite songs.'),\n",
       " Document(page_content='This is just a random text.')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.chains import LLMChain, StuffDocumentsChain\n",
    "from langchain.document_transformers import (\n",
    "    LongContextReorder,\n",
    ")\n",
    "from langchain.embeddings import HuggingFaceEmbeddings\n",
    "from langchain.llms import OpenAI\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.vectorstores import Chroma\n",
    "\n",
    "# Get embeddings.\n",
    "embeddings = HuggingFaceEmbeddings(model_name=\"all-MiniLM-L6-v2\")\n",
    "\n",
    "texts = [\n",
    "    \"Basquetball is a great sport.\",\n",
    "    \"Fly me to the moon is one of my favourite songs.\",\n",
    "    \"The Celtics are my favourite team.\",\n",
    "    \"This is a document about the Boston Celtics\",\n",
    "    \"I simply love going to the movies\",\n",
    "    \"The Boston Celtics won the game by 20 points\",\n",
    "    \"This is just a random text.\",\n",
    "    \"Elden Ring is one of the best games in the last 15 years.\",\n",
    "    \"L. Kornet is one of the best Celtics players.\",\n",
    "    \"Larry Bird was an iconic NBA player.\",\n",
    "]\n",
    "\n",
    "# Create a retriever\n",
    "retriever = Chroma.from_texts(texts, embedding=embeddings).as_retriever(\n",
    "    search_kwargs={\"k\": 10}\n",
    ")\n",
    "query = \"What can you tell me about the Celtics?\"\n",
    "\n",
    "# Get relevant documents ordered by relevance score\n",
    "docs = retriever.get_relevant_documents(query)\n",
    "docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "34fb9d6e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='The Celtics are my favourite team.'),\n",
       " Document(page_content='The Boston Celtics won the game by 20 points'),\n",
       " Document(page_content='Elden Ring is one of the best games in the last 15 years.'),\n",
       " Document(page_content='I simply love going to the movies'),\n",
       " Document(page_content='This is just a random text.'),\n",
       " Document(page_content='Fly me to the moon is one of my favourite songs.'),\n",
       " Document(page_content='Basquetball is a great sport.'),\n",
       " Document(page_content='Larry Bird was an iconic NBA player.'),\n",
       " Document(page_content='L. Kornet is one of the best Celtics players.'),\n",
       " Document(page_content='This is a document about the Boston Celtics')]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reorder the documents:\n",
    "# Less relevant document will be at the middle of the list and more\n",
    "# relevant elements at beginning / end.\n",
    "reordering = LongContextReorder()\n",
    "reordered_docs = reordering.transform_documents(docs)\n",
    "\n",
    "# Confirm that the 4 relevant documents are at beginning and end.\n",
    "reordered_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ceccab87",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'\\n\\nThe Celtics are referenced in four of the nine text extracts. They are mentioned as the favorite team of the author, the winner of a basketball game, a team with one of the best players, and a team with a specific player. Additionally, the last extract states that the document is about the Boston Celtics. This suggests that the Celtics are a basketball team, possibly from Boston, that is well-known and has had successful players and games in the past. '"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We prepare and run a custom Stuff chain with reordered docs as context.\n",
    "\n",
    "# Override prompts\n",
    "document_prompt = PromptTemplate(\n",
    "    input_variables=[\"page_content\"], template=\"{page_content}\"\n",
    ")\n",
    "document_variable_name = \"context\"\n",
    "llm = OpenAI()\n",
    "stuff_prompt_override = \"\"\"Given this text extracts:\n",
    "-----\n",
    "{context}\n",
    "-----\n",
    "Please answer the following question:\n",
    "{query}\"\"\"\n",
    "prompt = PromptTemplate(\n",
    "    template=stuff_prompt_override, input_variables=[\"context\", \"query\"]\n",
    ")\n",
    "\n",
    "# Instantiate the chain\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)\n",
    "chain = StuffDocumentsChain(\n",
    "    llm_chain=llm_chain,\n",
    "    document_prompt=document_prompt,\n",
    "    document_variable_name=document_variable_name,\n",
    ")\n",
    "chain.run(input_documents=reordered_docs, query=query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4696a97",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
